(WIP) Batched autodiff #2181
base: main
Conversation
```diff
@@ -27,7 +27,11 @@ getFunctionTypeForClone(mlir::FunctionType FTy, DerivativeMode mode,
   for (auto &&[Ty, returnPrimal, returnShadow, activity] : llvm::zip(
            FTy.getResults(), returnPrimals, returnShadows, ReturnActivity)) {
     if (returnPrimal) {
-      RetTypes.push_back(Ty);
+      if (width != 1) {
+        RetTypes.push_back(mlir::RankedTensorType::get({width}, Ty));
```
This shouldn't need changing, since the primal is always unmodified; only the derivatives are changed (and we should be pushing the getShadowType-derived types for those below).
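A minimal sketch of the suggested shape of that code, assuming `returnShadow` gates the shadow results and that `AutoDiffTypeInterface::getShadowType` accepts the batch width (consistent with the snippets later in this thread, but not the PR's actual code):

```cpp
if (returnPrimal) {
  // The primal result type is always unmodified, regardless of width.
  RetTypes.push_back(Ty);
}
if (returnShadow) {
  // Only the shadow is batched: for width != 1 this would turn e.g. f64
  // into tensor<2xf64>, assuming getShadowType performs the widening.
  RetTypes.push_back(cast<AutoDiffTypeInterface>(Ty).getShadowType(width));
}
```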
Oh, then I'm confused about what batched autodiff is.
How should my test case change?
Nvm, it clicked. It's just the shadow that's batched 😅
So here's an example from LLVM vector mode: https://github.com/EnzymeAD/Enzyme/blob/main/enzyme/test/Enzyme/ForwardModeVector/add.ll
Though perhaps mul will be more illustrative: https://github.com/EnzymeAD/Enzyme/blob/main/enzyme/test/Enzyme/ForwardModeVector/mul.ll (and obviously feel free to look at any/all of the other examples).
I haven't yet fully made the changes in enzyme-tblgen.cpp, and either way this only works for the simple test case so far:

```cpp
mlir::Value itmp = ({
  // Computing MulFOp
  auto fwdarg_0 = dif;
  auto fwdarg_1 = gutils->getNewFromOriginal(op->getOperand(1));
  if (gutils->width != 1) {
    // Broadcast the scalar operand to the batch width so the elementwise
    // mulf below type-checks against the batched shadow.
    fwdarg_1 = builder.create<tensor::SplatOp>(
        op.getLoc(),
        mlir::RankedTensorType::get({gutils->width}, fwdarg_1.getType()),
        fwdarg_1);
  }
  builder.create<arith::MulFOp>(op.getLoc(), fwdarg_0, fwdarg_1);
});
```

But this is the MLIR code that is generated for this simple test:

```mlir
func.func private @fwddiffe2square(%arg0: f64, %arg1: tensor<2xf64>) -> tensor<2xf64> {
  %splat = tensor.splat %arg0 : tensor<2xf64>
  %0 = arith.mulf %arg1, %splat : tensor<2xf64>
  %splat_0 = tensor.splat %arg0 : tensor<2xf64>
  %1 = arith.mulf %arg1, %splat_0 : tensor<2xf64>
  %2 = arith.addf %0, %1 : tensor<2xf64>
  %3 = arith.mulf %arg0, %arg0 : f64
  return %2 : tensor<2xf64>
}
```
This still requires changes in the tblgen-generated derivative files. For example, `createForwardModeTangent` in `MulFOpFwdDerivative` could be altered like this:

```cpp
LogicalResult createForwardModeTangent(Operation *op0, OpBuilder &builder,
                                       MGradientUtils *gutils) const {
  auto op = cast<arith::MulFOp>(op0);
  if (gutils->width != 1) {
    auto newop = gutils->getNewFromOriginal(op0);
    for (auto res : newop->getResults()) {
      res.setType(mlir::RankedTensorType::get({gutils->width}, res.getType()));
    }
  }
  gutils->eraseIfUnused(op);
  if (gutils->isConstantInstruction(op))
    return success();
  mlir::Value res = nullptr;
  if (!gutils->isConstantValue(op->getOperand(0))) {
    auto dif = gutils->invertPointerM(op->getOperand(0), builder);
    {
      mlir::Value itmp = ({
        // Computing MulFOp
        auto fwdarg_0 = dif;
        dif.dump();
        // TODO: gutils->makeBatched(...)
        auto fwdarg_1 = gutils->getNewFromOriginal(op->getOperand(1));
        builder.create<arith::MulFOp>(op.getLoc(), fwdarg_0, fwdarg_1);
      });
      itmp.dump();
      if (!res)
        res = itmp;
      else {
        auto operandType = cast<AutoDiffTypeInterface>(res.getType());
        res = operandType.createAddOp(builder, op.getLoc(), res, itmp);
      }
    }
  }
  if (!gutils->isConstantValue(op->getOperand(1))) {
    auto dif = gutils->invertPointerM(op->getOperand(1), builder);
    {
      mlir::Value itmp = ({
        // Computing MulFOp
        auto fwdarg_0 = dif;
        dif.dump();
        auto fwdarg_1 = gutils->getNewFromOriginal(op->getOperand(0));
        builder.create<arith::MulFOp>(op.getLoc(), fwdarg_0, fwdarg_1);
      });
      if (!res)
        res = itmp;
      else {
        auto operandType = cast<AutoDiffTypeInterface>(res.getType());
        res = operandType.createAddOp(builder, op.getLoc(), res, itmp);
      }
    }
  }
  assert(res);
  gutils->setDiffe(op->getResult(0), res, builder);
  return success();
}
```
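The `// TODO: gutils->makeBatched(...)` above suggests factoring the broadcast out of each tblgen-generated rule. A minimal sketch of what such a helper might look like, reusing the `tensor::SplatOp` pattern from the earlier snippet (the name `makeBatched` and its signature are hypothetical, not part of the PR):

```cpp
// Hypothetical helper: broadcast a scalar value to the batch width so it can
// be combined with an already-batched shadow value. For width == 1 the value
// is returned unchanged.
static mlir::Value makeBatched(MGradientUtils *gutils, OpBuilder &builder,
                               mlir::Location loc, mlir::Value v) {
  if (gutils->width == 1)
    return v;
  return builder.create<tensor::SplatOp>(
      loc, mlir::RankedTensorType::get({gutils->width}, v.getType()), v);
}
```

Each derivative rule could then call `fwdarg_1 = makeBatched(gutils, builder, op.getLoc(), fwdarg_1);` instead of open-coding the splat.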
This reverts commit c06ed01.
```tablegen
    NOTE: Only works for scalars and *ranked* tensors for now.
  }];

  let arguments = (ins AnyType:$input, I64Attr:$width);
```
To support Enzyme.batch [which is a bit more general, since it takes a shape rather than just a single int], do we want to make this a vararg of i64?
enzyme/Enzyme/MLIR/Dialect/Ops.cpp
```cpp
void BroadcastOp::build(OpBuilder &builder, OperationState &result,
                        Value input, llvm::SmallVector<int64_t> shape) {
  auto shapeAttr = builder.getDenseI64ArrayAttr(shape);
  RankedTensorType output;
  // TODO: support things other than scalars and ranked tensors,
  // maybe reuse getShadowType here?
```
Yeah, long term we should probably just use `getShadowType` (declared as `/*methodName=*/"getShadowType",` in the interface), so essentially do:

```cpp
auto ty = input.getType();
for (auto s : reverse(shape)) {
  ty = ty.cast<AutoDiffTypeInterface>().getShadowType(s);
}
```
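A sketch of how that loop could complete the `build` method from the diff above. It assumes `AutoDiffTypeInterface` is implemented for the input type and that an ODS-generated `build` overload taking an explicit result type exists; both are assumptions, not verified against the PR:

```cpp
void BroadcastOp::build(OpBuilder &builder, OperationState &result,
                        Value input, llvm::ArrayRef<int64_t> shape) {
  auto shapeAttr = builder.getDenseI64ArrayAttr(shape);
  // Fold getShadowType over the shape, innermost dimension first, so that a
  // scalar f64 with shape [2, 3] ends up as a batched 2x3 shadow type.
  mlir::Type ty = input.getType();
  for (auto s : llvm::reverse(shape))
    ty = mlir::cast<AutoDiffTypeInterface>(ty).getShadowType(s);
  build(builder, result, ty, input, shapeAttr);
}
```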
but for now this is fine
fix the format/etc then I think this is good to go!
Added some type conversions to tensor types if `width != 1`. The simple test case seems correct now.

Corresponding Enzyme-JAX PR: EnzymeAD/Enzyme-JAX#197
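For reference, the kind of conversion described here, written out as a standalone sketch (the helper name `batchedType` is illustrative; the PR inlines this logic at each use site):

```cpp
// Widen a scalar type to its batched counterpart: f64 becomes tensor<2xf64>
// for width == 2, while width == 1 leaves the type unchanged. This mirrors
// the RankedTensorType::get({width}, Ty) calls in the diff above.
static mlir::Type batchedType(mlir::Type Ty, int64_t width) {
  if (width == 1)
    return Ty;
  return mlir::RankedTensorType::get({width}, Ty);
}
```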